2 The model

Our interest is in modeling a sequence of scatter plots measured over time. That is, we observe

\(Y_{it}\in\mathbb R^d\) for \(i=1,\ldots,n_t\) and \(t=1,\ldots,T\).

Continuous-time flow cytometry data have two notable properties:

  1. Each scatter plot looks approximately like a mixture of Gaussians.

  2. The general clustering structure seen in each scatter plot is slowly varying over time.

To model data like this, we wish to fit a smoothly-varying mixture of Gaussians model:

\[ Y_{it}|\{Z_{it}=k\}\sim N_d(\mu_{kt},\Sigma_{kt})\qquad\mathbb P(Z_{it}=k)=\pi_{kt} \] where \((\mu_{kt},\Sigma_{kt},\pi_{kt})\) are slowly varying parameters.

It will be useful to have data generated from this model for testing purposes, so we begin by defining a function for simulating from this model.

2.1 Generating data from model

#' Generate data from smoothly-varying mixture of Gaussians model
#'
#' The smoothly-varying mixture of Gaussians model is defined as follows:
#'
#' At time t there are n_t points generated as follows:
#'
#' Y_{it}|\{Z_{it}=k\} ~ N_d(mu_{kt},Sigma_{kt})
#' where
#' P(Z_{it}=k)=pi_{kt}
#' and the parameters (mu_t, Sigma_t, pi_t) are all slowly varying in time.
#'
#' This function generates Y and Z.
#'
#' @param mu_function a function that maps a vector of times to a T-by-K-by-d
#' array of means
#' @param Sigma_function a function that maps a vector of times to a
#' T-by-K-by-d-by-d array of covariance matrices (when d = 1 each entry is a
#' variance, not a standard deviation)
#' @param pi_function a function that maps a vector of times to a T-by-K matrix
#' of probabilities
#' @param num_points a T vector of integers giving the number of points n_t to
#' generate at each time point t.
#' @return A list with elements `y` (length T list, `y[[t]]` an n_t-by-d
#' matrix of observations), `z` (length T list, `z[[t]]` an n_t vector of
#' component labels), and the evaluated parameters `mu`, `Sigma` and `pi`.
#' @export
generate_smooth_gauss_mix <- function(mu_function,
                                      Sigma_function,
                                      pi_function,
                                      num_points) {
  times <- seq_along(num_points)
  mu <- mu_function(times)
  Sigma <- Sigma_function(times)
  pi <- pi_function(times)
  K <- ncol(pi) # number of components
  d <- dim(mu)[3]
  dimnames(mu) <- list(NULL, paste0("cluster", seq_len(K)), NULL)

  z <- vector("list", length(times)) # z[[t]][i] = class of point i at time t
  y <- vector("list", length(times)) # y[[t]][i,] = d-vector of point i at time t
  for (t in times) {
    n_t <- num_points[t]
    # Each column of `draws` is a one-hot indicator of the sampled component;
    # multiplying by 1:K converts it to the component index (and also works
    # when n_t is 0).
    draws <- stats::rmultinom(n_t, 1, pi[t, ])
    z[[t]] <- as.integer(seq_len(K) %*% draws)
    y[[t]] <- matrix(NA_real_, n_t, d)
    for (k in seq_len(K)) {
      ii <- z[[t]] == k # index of points in component k at time t
      n_k <- sum(ii)
      if (n_k == 0) next
      if (d == 1) {
        # Sigma holds variances, whereas rnorm() expects a standard deviation.
        y[[t]][ii, ] <- stats::rnorm(n = n_k,
                                     mean = mu[t, k, 1],
                                     sd = sqrt(Sigma[t, k, 1, 1]))
      } else {
        y[[t]][ii, ] <- mvtnorm::rmvnorm(n = n_k,
                                         mean = mu[t, k, ],
                                         sigma = Sigma[t, k, , ])
      }
    }
  }
  list(y = y, z = z, mu = mu, Sigma = Sigma, pi = pi)
}

We have used two packages in this function, so let’s add these into our package.

usethis::use_package("stats")
usethis::use_package("mvtnorm")
## ✔ Adding 'stats' to Imports field in DESCRIPTION
## • Refer to functions with `stats::fun()`
## ✔ Adding 'mvtnorm' to Imports field in DESCRIPTION
## • Refer to functions with `mvtnorm::fun()`

Let’s generate simple examples in the \(d=1\) and \(d=3\) cases:

set.seed(123)
d <- 1; K <- 2; ntimes <- 200

# One-dimensional example with two components.
# The parameter functions size their output from `times` and from fixed
# constants rather than from the globals d, K and ntimes, because those
# globals are reassigned when the second example is defined.
ex1 <- list(
  mu_function = function(times) {
    # Component 1 oscillates with period 30; component 2 is constant at 2.
    means <- cbind(sin(2 * pi * times / 30), 2)
    array(means, c(nrow(means), ncol(means), 1))
  },
  Sigma_function = function(times) {
    # Constant variance 0.25 for both components (T-by-2-by-1-by-1).
    array(0.25, c(length(times), 2, 1, 1))
  },
  pi_function = function(times) {
    # Weight of component 1 grows linearly from 0.2 to 0.8.
    pi1 <- seq(0.2, 0.8, length.out = length(times))
    cbind(pi1, 1 - pi1)
  },
  num_points = rep(40, ntimes)
)
ex1$dat <- generate_smooth_gauss_mix(ex1$mu_function,
                                     ex1$Sigma_function,
                                     ex1$pi_function,
                                     ex1$num_points)
d <- 3; K <- 4; ntimes <- 200

# Three-dimensional example with four components.
# The parameter functions size their output from `times` and from fixed
# constants so that they do not depend on the mutable globals d, K, ntimes.
ex2 <- list(
  mu_function = function(times) {
    phase <- 2 * pi * times / 30
    # Each matrix is T-by-4: column k is the mean of cluster k along that
    # coordinate.
    coord1 <- cbind(0.5 * cos(phase), 0.3 * sin(phase), sin(phase), -3)
    coord2 <- cbind(0.3 * sin(phase), 2, -1, 0.6 * cos(phase))
    coord3 <- cbind(2, 0.7 * cos(phase), 0.4 * sin(phase), 1)
    array(c(coord1, coord2, coord3), c(length(times), 4, 3))
  },
  Sigma_function = function(times) {
    # Same covariance 0.10 * I_3 for every cluster and time.
    Sigma <- array(0, c(length(times), 4, 3, 3))
    for (j in seq_len(3)) {
      Sigma[, , j, j] <- 0.10
    }
    Sigma
  },
  pi_function = function(times) {
    pi1 <- seq(0.2, 0.3, length.out = length(times))
    # The last weight is whatever remains so that each row sums to one.
    cbind(pi1, pi1, 2 * pi1 / 3, 1 - (2 * pi1 + 2 * pi1 / 3))
  },
  num_points = rep(150, ntimes)
)
ex2$dat <- generate_smooth_gauss_mix(ex2$mu_function,
                                     ex2$Sigma_function,
                                     ex2$pi_function,
                                     ex2$num_points)

2.2 Visualizing the raw data:

Let’s make a function for visualizing the data in the one-dimensional and three-dimensional cases.

library(magrittr) # we'll be using the pipe in this document

The function will take as input the following argument:

#' @param y length T list with `y[[t]]` being a n_t-by-d matrix
y-param

We define this bit of documentation in its own code chunk so that it can be easily reused since multiple functions in the package take it as input.

#' Plot raw data
#' 
<<y-param>>
#' 
#' @export
plot_data <- function(y) {
  # Plot the raw observations over time.
  #   d = 1: ggplot scatter plot of value against time.
  #   d = 3: animated plotly 3-d scatter plot with one frame per time point.
  # Any other dimension is rejected explicitly.
  d <- ncol(y[[1]])
  if (is.null(d) || !(d %in% c(1, 3))) {
    stop("plot_data() only supports d = 1 or d = 3; got d = ", d,
         call. = FALSE)
  }
  # Drop names so that `.id = "time"` below is the position in the list.
  y <- unname(y)

  if (d == 1) {
    fig <- purrr::map_dfr(y, ~ tibble::tibble(y = .x), .id = "time") %>%
      dplyr::mutate(time = as.numeric(.data$time)) %>%
      ggplot2::ggplot(ggplot2::aes(x = .data$time, y = .data$y)) +
      ggplot2::geom_point(alpha = 0.2)
    return(fig)
  }

  # Fixed axis limits shared by all frames: the overall range of coordinate j
  # widened by 10% of its span on each side. (Multiplying the extremes by 1.1
  # would cut off data whenever a minimum is positive or a maximum negative.)
  axis_range <- function(j) {
    lo <- min(vapply(y, function(mat) min(mat[, j]), numeric(1)))
    hi <- max(vapply(y, function(mat) max(mat[, j]), numeric(1)))
    pad <- 0.1 * (hi - lo)
    c(lo - pad, hi + pad)
  }

  points_df <- y %>%
    purrr::map_dfr(
      ~ tibble::tibble(x = .x[, 1], y = .x[, 2], z = .x[, 3]),
      .id = "time"
    )
  points_df$time <- as.integer(points_df$time)

  plotly::plot_ly(
    data = points_df,
    x = ~x, y = ~y, z = ~z,
    type = "scatter3d", frame = ~time, mode = "markers", size = 80,
    colors = grDevices::colorRamp(c("blue", "red", "purple", "cyan", "magenta", "brown", "gray", "darkgreen", "darkblue", "darkred", "darkorange"))
  ) %>%
    plotly::layout(title = 'Raw Data', scene = list(
      xaxis = list(range = axis_range(1), title = 'diam_mid'),
      yaxis = list(range = axis_range(2), title = 'Chl_small'),
      zaxis = list(range = axis_range(3), title = 'pe'),
      aspectmode = "manual",
      aspectratio = list(x = 1, y = 1, z = 1)  # Specify the fixed aspect ratio
    ))
}

We’ve used some functions from other packages, so let’s include those in our package:

usethis::use_pipe()
usethis::use_package("purrr")
usethis::use_package("tibble")
usethis::use_package("dplyr")
usethis::use_package("ggplot2")
usethis::use_import_from("rlang", ".data")
usethis::use_package("plotly")
usethis::use_package("grDevices")
## ✔ Adding 'magrittr' to Imports field in DESCRIPTION
## ✔ Writing 'R/utils-pipe.R'
## • Run `devtools::document()` to update 'NAMESPACE'
## ✔ Adding 'purrr' to Imports field in DESCRIPTION
## • Refer to functions with `purrr::fun()`
## ✔ Adding 'tibble' to Imports field in DESCRIPTION
## • Refer to functions with `tibble::fun()`
## ✔ Adding 'dplyr' to Imports field in DESCRIPTION
## • Refer to functions with `dplyr::fun()`
## ✔ Adding 'ggplot2' to Imports field in DESCRIPTION
## • Refer to functions with `ggplot2::fun()`
## ✔ Adding 'rlang' to Imports field in DESCRIPTION
## ✔ Adding '@importFrom rlang .data' to 'R/flowkernel-package.R'
## ✔ Writing 'NAMESPACE'
## ✔ Adding 'plotly' to Imports field in DESCRIPTION
## • Refer to functions with `plotly::fun()`
## ✔ Adding 'grDevices' to Imports field in DESCRIPTION
## • Refer to functions with `grDevices::fun()`

Let’s look at our two examples using this plotting function:

plot_data(ex1$dat$y)
plot_data(ex2$dat$y)

2.3 Visualizing data and model

We’ll also want a function for plotting the data with points colored by true (or estimated) cluster. And it will be convenient to also be able to superimpose the true (or estimated) means. The next function does this:

#' Plot data colored by cluster assignment with cluster means
#' 
<<y-param>>
#' @param z a length T list with `z[[t]]` being a n_t vector of cluster assignments
#' @param mu a T-by-K-by-d array of means
#' @export
plot_data_and_model <- function(y, z, mu) {
  # Plot the observations coloured by cluster label together with the cluster
  # means.
  #   d = 1: ggplot with points and one mean line per cluster.
  #   d = 3: animated plotly 3-d scatter plot with the means added as extra
  #          marker traces and buttons to show/hide the data points.
  # Any other dimension is rejected explicitly.
  d <- ncol(y[[1]])
  if (is.null(d) || !(d %in% c(1, 3))) {
    stop("plot_data_and_model() only supports d = 1 or d = 3; got d = ", d,
         call. = FALSE)
  }
  K <- ncol(mu)
  n_mu <- dim(mu)[1]
  # Drop names so that `.id = "time"` below is the position in the list.
  y <- unname(y)
  z <- unname(z)

  if (d == 1) {
    dat_df <- purrr::map2_dfr(z, y,
                              ~ tibble::tibble(z = as.factor(.x), y = .y),
                              .id = "time") %>%
      dplyr::mutate(time = as.numeric(.data$time))

    # Rebuild the T-by-K matrix of means explicitly so that K = 1 (where
    # `mu[, , 1]` drops to a vector) and unnamed clusters are handled.
    cluster_names <- dimnames(mu)[[2]]
    if (is.null(cluster_names)) {
      cluster_names <- paste0("cluster", seq_len(K))
    }
    means_mat <- matrix(mu[, , 1], nrow = n_mu, ncol = K,
                        dimnames = list(NULL, cluster_names))
    means_df <- tibble::as_tibble(means_mat) %>%
      dplyr::mutate(time = dplyr::row_number()) %>%
      tidyr::pivot_longer(dplyr::all_of(cluster_names),
                          names_to = "cluster", values_to = "mean")

    fig <- ggplot2::ggplot() +
      ggplot2::geom_point(
        data = dat_df,
        ggplot2::aes(x = .data$time, y = .data$y, color = .data$z),
        alpha = 0.2
      ) +
      ggplot2::geom_line(
        data = means_df,
        ggplot2::aes(x = .data$time, y = .data$mean, group = .data$cluster)
      ) +
      ggplot2::labs(x = "Time", y = "Cell Diameter")
    return(fig)
  }

  # Fixed axis limits shared by all frames: the overall range of coordinate j
  # widened by 10% of its span on each side. (Multiplying the extremes by 1.1
  # would cut off data whenever a minimum is positive or a maximum negative.)
  axis_range <- function(j) {
    lo <- min(vapply(y, function(mat) min(mat[, j]), numeric(1)))
    hi <- max(vapply(y, function(mat) max(mat[, j]), numeric(1)))
    pad <- 0.1 * (hi - lo)
    c(lo - pad, hi + pad)
  }

  # One row per observation; z1 holds the cluster label used for colouring.
  points_df <- y %>%
    purrr::map_dfr(
      ~ tibble::tibble(x = .x[, 1], y = .x[, 2], z = .x[, 3]),
      .id = "time"
    ) %>%
    dplyr::mutate(z1 = unlist(z))
  points_df$time <- as.integer(points_df$time)

  # One data frame of mean trajectories per cluster.
  time_index <- seq_len(n_mu)
  center_frames <- lapply(seq_len(K), function(kk) {
    data.frame(
      X1 = mu[, kk, 1],
      X2 = mu[, kk, 2],
      X3 = mu[, kk, 3],
      time = time_index
    )
  })

  fig <- points_df %>% plotly::plot_ly(
    x = ~x, y = ~y, z = ~z, color = ~z1,
    type = "scatter3d", frame = ~time, mode = "markers", size = 80,
    colors = grDevices::colorRamp(c("blue", "orange", "red"))) %>%
    plotly::layout(scene = list(
      xaxis = list(title = "Diameter", range = axis_range(1)),
      yaxis = list(title = "Chl_Small", range = axis_range(2)),
      zaxis = list(title = "PE", range = axis_range(3)),
      aspectmode = "manual",  # Set aspect ratio to manual
      aspectratio = list(x = 1, y = 1, z = 1)  # Specify the fixed aspect ratio
    ))

  # Buttons toggling the first trace (the data points); the remaining
  # visibility flags are left TRUE.
  updatemenus <- list(
    list(
      active = 0,
      type = 'buttons',
      buttons = list(
        list(
          label = "Data Points",
          method = "update",
          args = list(list(visible = c(TRUE, rep(c(TRUE, TRUE), K))))),
        list(
          label = "No Data Points",
          method = "update",
          args = list(list(visible = c(FALSE, rep(c(TRUE, TRUE), K))))))
    )
  )

  for (kk in seq_len(K)) {
    fig <- fig %>%
      plotly::add_markers(data = center_frames[[kk]],
                          x = ~X1, y = ~X2, z = ~X3,
                          color = kk, size = 120, frame = ~time)
  }
  # The menu only needs to be attached once, after all traces exist.
  fig %>% plotly::layout(updatemenus = updatemenus)
}

We used a function from tidyr, so let’s include this package:

usethis::use_package("tidyr")
## ✔ Adding 'tidyr' to Imports field in DESCRIPTION
## • Refer to functions with `tidyr::fun()`

For now we can use this to visualize the true model, although later this will be useful for visualizing the estimated model.

plot_data_and_model(ex1$dat$y, ex1$dat$z, ex1$dat$mu)
plot_data_and_model(ex2$dat$y, ex2$dat$z, ex2$dat$mu)

In 3-d, with a large number of data points, these plots might become too crowded to really appreciate. We therefore have a function that just shows how the cluster centers \(\mu_{kt}\) evolve with time:

#' Plot cluster centers in 3-d
#' 
<<y-param>>
#' @param z a length T list with `z[[t]]` being a n_t vector of cluster assignments
#' @param mu a T-by-K-by-d array of means
#' @export
plot_3d_centers <- function(y, z, mu) {
  # Animated plotly 3-d plot of the cluster centers only: one marker trace per
  # cluster, one frame per time point.
  # `y` and `z` are kept in the signature for compatibility with existing
  # callers, but only `mu` is needed to draw the centers, so they are not
  # inspected here.
  mu_dims <- dim(mu)
  if (length(mu_dims) != 3 || mu_dims[3] < 3) {
    stop("plot_3d_centers() needs `mu` to be a T-by-K-by-d array with d >= 3",
         call. = FALSE)
  }
  time_index <- seq_len(mu_dims[1])

  fig <- plotly::plot_ly()
  for (kk in seq_len(mu_dims[2])) {
    # Index each coordinate directly so a single time point does not collapse
    # the slice into a vector.
    centers <- data.frame(
      X1 = mu[, kk, 1],
      X2 = mu[, kk, 2],
      X3 = mu[, kk, 3],
      time = time_index
    )
    fig <- fig %>%
      plotly::add_markers(data = centers, x = ~X1, y = ~X2, z = ~X3,
                          color = kk, size = 120, frame = ~time)
  }

  fig
}

Let’s try this out for our 3-d example:

plot_3d_centers(ex2$dat$y, ex2$dat$z, ex2$dat$mu)

There are several other ways of looking at our data and model that might be useful. We would like to see how the \(\pi_{kt}\)’s evolve with time, how a 1-d projection of the 3-d cluster means evolves with time, and how the biomass of a particular cluster evolves over time. Let’s add these functions to our package.

#' Plot cluster populations (pi) over time
#'
#' @param pi A T-by-K array, with each row consisting of probabilities that sum to one
#' @return An interactive plotly object (a ggplot converted with
#' `plotly::ggplotly()`) with one line per cluster.
#' @export
plot_pi <- function(pi) {
  if (length(dim(pi)) != 2) {
    stop("`pi` must be a T-by-K matrix of mixing probabilities", call. = FALSE)
  }
  probs <- as.matrix(pi)
  n_times <- nrow(probs)
  K <- ncol(probs)

  # Long format: one row per (time, cluster). The factor levels keep the
  # legend and the colours in cluster order (1, 2, ..., K).
  cluster_labels <- paste("Cluster", seq_len(K))
  df <- data.frame(
    time = rep(seq_len(n_times), times = K),
    value = as.vector(probs),
    cluster = factor(rep(cluster_labels, each = n_times),
                     levels = cluster_labels)
  )

  pi_plt <- ggplot2::ggplot(
    df,
    ggplot2::aes(x = .data$time, y = .data$value, color = .data$cluster)
  ) +
    ggplot2::geom_line(linetype = "solid") +
    ggplot2::labs(x = "Time", y = "Pi") +
    ggplot2::ggtitle("Pi Over Time") +
    ggplot2::scale_color_manual(name = "Cluster",
                                values = grDevices::rainbow(K))

  # Convert ggplot to plotly for interactivity
  plotly::ggplotly(pi_plt, dynamicTicks = TRUE)
}
plot_pi(ex2$dat$pi)
#' Plot cluster centers as 1-d projection over time
#'
#' @param mu a T-by-K-by-d array of means
#' @param dim specify which dimension to be plotted: 1 (diam), 2 (chl_small), or 3 (pe)
#' @return An interactive plotly object (a ggplot converted with
#' `plotly::ggplotly()`) with one line per cluster.
#' @export
plot_1d_means <- function(mu, dim = 1) {
  # `dim` is also the name of an argument, so call the base function explicitly.
  mu_dims <- base::dim(mu)
  if (length(mu_dims) != 3) {
    stop("`mu` must be a T-by-K-by-d array", call. = FALSE)
  }
  if (!is.numeric(dim) || length(dim) != 1 || is.na(dim) ||
      dim < 1 || dim > mu_dims[3]) {
    stop("`dim` must be a single index between 1 and ", mu_dims[3],
         call. = FALSE)
  }
  n_times <- mu_dims[1]
  K <- mu_dims[2]

  # Long format: one row per (time, cluster). The factor levels keep the
  # legend and the colours in cluster order (1, 2, ..., K).
  cluster_labels <- paste("Cluster", seq_len(K))
  df <- data.frame(
    time = rep(seq_len(n_times), times = K),
    value = as.vector(mu[, , dim]),
    cluster = factor(rep(cluster_labels, each = n_times),
                     levels = cluster_labels)
  )

  # Determine the y-axis label based on the dimension. switch() must be given
  # a character value for the names and the unnamed default to apply.
  y_label <- switch(as.character(dim),
                    "1" = "Diameter",
                    "2" = "chl_small",
                    "3" = "pe",
                    paste("Dimension", dim))

  pi_plt <- ggplot2::ggplot(
    df,
    ggplot2::aes(x = .data$time, y = .data$value, color = .data$cluster)
  ) +
    ggplot2::geom_line(linetype = "solid") +
    ggplot2::labs(x = "Time", y = y_label) +
    ggplot2::ggtitle("Cluster Means Over Time") +
    ggplot2::scale_color_manual(name = "Cluster",
                                values = grDevices::rainbow(K))

  # Convert ggplot to plotly for interactivity
  plotly::ggplotly(pi_plt, dynamicTicks = TRUE)
}
plot_1d_means(ex2$dat$mu, dim = 1)
plot_1d_means(ex2$dat$mu, dim = 2)
plot_1d_means(ex2$dat$mu, dim = 3)

If we want to plot all three dimensions together, we can use the following function:

#' Plot cluster means as 1-d projection over time, with all three dimensions plotted together, in separate plots
#'
#' @param mu a T-by-K-by-d array of means
#' @return The object returned by `gridExtra::grid.arrange()`: the three
#' panels (Diameter, chl_small, PE) stacked vertically. The arrangement is
#' also drawn as a side effect.
#' @export
plot_1d_means_triple <- function(mu) {
  mu_dims <- dim(mu)
  if (length(mu_dims) != 3 || mu_dims[3] < 3) {
    stop("`mu` must be a T-by-K-by-d array with d >= 3", call. = FALSE)
  }
  n_times <- mu_dims[1]
  K <- mu_dims[2]
  cluster_labels <- paste("Cluster", seq_len(K))
  axis_labels <- c("Diameter", "chl_small", "PE")

  # Build one panel per coordinate; the three panels differ only in the slice
  # of `mu` they show and in their labels.
  panels <- lapply(seq_along(axis_labels), function(j) {
    df <- data.frame(
      time = rep(seq_len(n_times), times = K),
      value = as.vector(mu[, , j]),
      cluster = factor(rep(cluster_labels, each = n_times),
                       levels = cluster_labels)
    )
    ggplot2::ggplot(
      df,
      ggplot2::aes(x = .data$time, y = .data$value, color = .data$cluster)
    ) +
      ggplot2::geom_line(linetype = "solid") +
      ggplot2::labs(x = "Time", y = axis_labels[j]) +
      ggplot2::ggtitle(paste("Means of", axis_labels[j], "Over Time")) +
      ggplot2::scale_color_manual(name = "Cluster",
                                  values = grDevices::rainbow(K))
  })

  # Arrange the three plots vertically
  gridExtra::grid.arrange(grobs = panels, ncol = 1)
}

Here is an example of running this function:

## TableGrob (3 x 1) "arrange": 3 grobs
##   z     cells    name           grob
## 1 1 (1-1,1-1) arrange gtable[layout]
## 2 2 (2-2,1-1) arrange gtable[layout]
## 3 3 (3-3,1-1) arrange gtable[layout]
usethis::use_package("gridExtra")
## ✔ Adding 'gridExtra' to Imports field in DESCRIPTION
## • Refer to functions with `gridExtra::fun()`

We now add two more plotting functions to our package: one that plots the 1-d means as above, but with the width of each line varying according to pi of each cluster, and another to plot the total biomass over time for each cluster.

#' Plot cluster means as 1-d projection over time, with line widths determined by pi
#'
#' @param mu a T-by-K-by-d array of means
#' @param pi A T-by-K array, with each row consisting of probabilities that sum to one
#' @param dim specify which dimension to be plotted: 1 (diam), 2 (chl_small), or 3 (pe)
#' @return A ggplot object with one line per cluster whose width at each time
#' follows that cluster's entry of `pi`.
#' @export
plot_1d_means_with_width <- function(mu, pi, dim = 1) {
  # `dim` is also the name of an argument, so call the base function explicitly.
  mu_dims <- base::dim(mu)
  if (length(mu_dims) != 3) {
    stop("`mu` must be a T-by-K-by-d array", call. = FALSE)
  }
  if (!is.numeric(dim) || length(dim) != 1 || is.na(dim) ||
      dim < 1 || dim > mu_dims[3]) {
    stop("`dim` must be a single index between 1 and ", mu_dims[3],
         call. = FALSE)
  }
  n_times <- mu_dims[1]
  K <- mu_dims[2]
  if (length(base::dim(pi)) != 2 || nrow(pi) != n_times || ncol(pi) != K) {
    stop("`pi` must be a T-by-K matrix matching the first two dimensions of `mu`",
         call. = FALSE)
  }

  # Long format: one row per (time, cluster), carrying both the mean and the
  # mixing weight that sets the line width.
  cluster_labels <- paste("Cluster", seq_len(K))
  df <- data.frame(
    time = rep(seq_len(n_times), times = K),
    value = as.vector(mu[, , dim]),
    weight = as.vector(as.matrix(pi)),
    cluster = factor(rep(cluster_labels, each = n_times),
                     levels = cluster_labels)
  )

  # switch() must be given a character value for the names and the unnamed
  # default to apply.
  y_label <- switch(as.character(dim),
                    "1" = "Diameter",
                    "2" = "chl_small",
                    "3" = "pe",
                    paste("Dimension", dim))

  ggplot2::ggplot(
    df,
    ggplot2::aes(x = .data$time, y = .data$value, color = .data$cluster,
                 linewidth = .data$weight)
  ) +
    ggplot2::geom_line(linetype = "solid") +
    ggplot2::labs(x = "Time", y = y_label) +
    ggplot2::ggtitle("Cluster Means Over Time") +
    ggplot2::scale_color_manual(name = "Cluster",
                                values = grDevices::rainbow(K)) +
    ggplot2::guides(linewidth = "none")  # To remove the linewidth legend
}
plot_1d_means_with_width(ex2$dat$mu, ex2$dat$pi, dim = 3)

Let’s also create a function to plot the biomass of each cluster over time. We will run this function later when we have responsibilities defined:

#' Plot biomass over time for each cluster
#'
#' @param biomass A list of length T, where each element `biomass[[t]]` is a numeric vector of length n_t containing the biomass (or count) of particles in each bin
#' @param resp length T list with `resp[[t]]` being a n_t-by-K matrix
#' @return An interactive plotly object (a ggplot converted with
#' `plotly::ggplotly()`) with one line per cluster showing the
#' responsibility-weighted biomass at each time.
#' @export
plot_biomass <- function(biomass, resp) {
  if (length(biomass) != length(resp)) {
    stop("`biomass` and `resp` must have the same length (one entry per time point)",
         call. = FALSE)
  }
  K <- ncol(resp[[1]])
  ntimes <- length(resp)

  # cluster_biomass[t, k] = sum_i resp[[t]][i, k] * biomass[[t]][i]
  cluster_biomass <- matrix(NA_real_, nrow = ntimes, ncol = K)
  for (tt in seq_len(ntimes)) {
    cluster_biomass[tt, ] <- as.vector(crossprod(resp[[tt]], biomass[[tt]]))
  }

  # Long format: one row per (time, cluster). The factor levels keep the
  # legend in cluster order (1, 2, ..., K).
  cluster_labels <- paste("Cluster", seq_len(K))
  df <- data.frame(
    time = rep(seq_len(ntimes), times = K),
    value = as.vector(cluster_biomass),
    cluster = factor(rep(cluster_labels, each = ntimes),
                     levels = cluster_labels)
  )

  pi_plt <- ggplot2::ggplot(
    df,
    ggplot2::aes(x = .data$time, y = .data$value, color = .data$cluster)
  ) +
    ggplot2::geom_line(linetype = "solid") +
    ggplot2::labs(x = "Time", y = "Cluster Biomass", color = "Cluster") +
    ggplot2::ggtitle("Cluster Biomass Over Time")

  # Convert ggplot to plotly for interactivity
  plotly::ggplotly(pi_plt, dynamicTicks = TRUE)
}